'''
Supervised Instruction Tuning 
'''

import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"
import argparse
import json
from accelerate import Accelerator
from datasets import load_dataset, Dataset, DatasetDict, load_from_disk
from peft import LoraConfig, AdaLoraConfig, LoHaConfig, IA3Config, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, logging, set_seed
from trl import SFTTrainer
import datetime
from tools.prompt import expert_prompt_text
from sklearn.model_selection import train_test_split

# Dataset name -> on-disk location, looked up by get_dataset().
# The wikiMIA_* entries are JSON files read through get_dataset_json(), the
# gsm8k entry is read line by line, and the remaining entries are loaded with
# datasets.load_from_disk().
# NOTE(review): these are absolute, machine-specific paths.
data_path_config = {
    "StackMIA": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/StackMIAsub",
    "wikiMIA_2023": "/data/home/zhanghx/code/DataContaminate/benchmarks/fine_tuning/train_data.json",
    "wikiMIA_ontime": "/data/home/zhanghx/code/DataContaminate/benchmarks/no_time/train_data.json",
    "gsm8k": "/data/home/zhanghx/code/DataContaminate/benchmarks/contam-1.4b/gsm8k.jsonl",
    "BookMIA": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/BookMIA",
    "BookTection": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/BookTection",
    "arXivTection": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/arXivTection",
    "tofu": "/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/tofu",
}

# Short model name -> local checkpoint directory.
# NOTE(review): nothing in the visible part of this file reads this mapping;
# --model_id is passed to from_pretrained() directly. Confirm whether it is
# still needed.
model_path_config = {
    "llama-7b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-7b",
    "llama-13b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-13b",
    "llama-30b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-30b",
    "llama-65b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-65b",
    "llama-2-7b": "/mnt/sharedata/ssd/users/zhanghx/models/llama-2-7b-hf",
    "pythia-6.9b": "/mnt/sharedata/ssd/users/zhanghx/models/pythia-6.9b"
}

def _sample_train_rows(ds, args):
    """Shuffle ``ds`` with ``args.seed`` and return the 'train' part of a seeded split.

    The split keeps ``args.train_size`` of the rows as 'train'.
    """
    ds = ds.shuffle(seed=args.seed)
    return ds.train_test_split(train_size=args.train_size, seed=args.seed)['train']


def _to_text_label_dataset(rows, text_key, label_key=None, fixed_label=None):
    """Project ``rows`` onto a two-column ``{'text', 'label'}`` Dataset.

    ``text_key`` names the source column copied into 'text'. The label is read
    from ``label_key`` when given; otherwise every row gets ``fixed_label``.
    """
    texts, labels = [], []
    for row in rows:
        texts.append(row[text_key])
        labels.append(fixed_label if label_key is None else row[label_key])
    return Dataset.from_dict({'text': texts, 'label': labels})


def get_dataset(args):
    """Load the benchmark named by ``args.dataset`` and return 100 non-member rows.

    The dataset name is the last '/'-separated component of ``args.dataset``.
    Each supported benchmark is normalised to 'text' / 'label' columns, then
    the first 100 rows whose label is 0 are collected, in dataset order, into
    the returned Dataset.

    Args:
        args: namespace providing ``dataset``, ``seed`` and ``train_size``.

    Returns:
        A ``Dataset`` with columns 'text' and 'label' holding exactly 100 rows,
        all with label 0.

    Raises:
        ValueError: if the dataset name is not supported.
        AssertionError: if fewer than 100 rows with label 0 are available.
            NOTE(review): the 'gsm8k' branch labels every row 1 and the 'tofu'
            branch labels every row 2, so both currently always end here.
    """
    dataset_name = args.dataset.split('/')[-1]
    if dataset_name == 'StackMIA':
        ds = load_from_disk(data_path_config["StackMIA"])
        dataset = _to_text_label_dataset(_sample_train_rows(ds, args), 'snippet', 'label')
    elif dataset_name == 'wikiMIA_ontime':
        dataset = get_dataset_json(data_path_config["wikiMIA_ontime"])
    elif dataset_name == 'wikiMIA_2023':
        dataset = get_dataset_json(data_path_config["wikiMIA_2023"])
    elif dataset_name == 'gsm8k':
        with open(data_path_config['gsm8k'], "r") as f:
            raw_lines = f.readlines()
        member_labels = [1] * len(raw_lines)
        # No local variable in this function may be named `train_test_split`:
        # such an assignment makes the name function-local and this call to the
        # sklearn helper would raise UnboundLocalError.
        _, held_text, _, held_label = train_test_split(
            raw_lines, member_labels, test_size=0.3, random_state=args.seed)
        # The 30% "test" part is what gets used for fine-tuning here.
        # NOTE(review): presumably intentional (a 30% subset) - confirm.
        dataset = Dataset.from_dict({'text': held_text, 'label': held_label})
    elif dataset_name == 'WikiMIA':
        ds = load_from_disk("/mnt/sharedata/ssd/users/zhanghx/dataset/benchmark/WikiMIA")
        texts, wiki_labels = [], []
        for length in (32, 64, 128, 256):
            # Dataset.shuffle returns a new dataset; the result must be kept
            # (done inside _sample_train_rows) for the shuffle to take effect.
            subset = _sample_train_rows(ds[f"WikiMIA_length{length}"], args)
            for row in subset:
                texts.append(row['input'])
                wiki_labels.append(row['label'])
        dataset = Dataset.from_dict({'text': texts, 'label': wiki_labels})
    elif dataset_name == 'BookMIA':
        ds = load_from_disk(data_path_config["BookMIA"])['train']
        dataset = _to_text_label_dataset(_sample_train_rows(ds, args), 'snippet', 'label')
    elif dataset_name == 'BookTection' or dataset_name == 'arXivTection':
        ds = load_from_disk(data_path_config[dataset_name])['train']
        dataset = _to_text_label_dataset(_sample_train_rows(ds, args), 'Example_A', 'Label')
    elif dataset_name == 'tofu':
        ds = load_from_disk(data_path_config["tofu"])['train']
        # 2 represents "label unknown".
        dataset = _to_text_label_dataset(_sample_train_rows(ds, args), 'question', fixed_label=2)
    else:
        raise ValueError("Unsupported dataset")
    print("dataset_nums: ", len(dataset))

    required = 100  # number of non-member (label 0) rows to keep
    seen_labels = []
    sample = {'text': [], 'label': []}
    for ex in dataset:
        seen_labels.append(ex['label'])
        if ex['label'] == 0:
            sample['text'].append(ex['text'])
            sample['label'].append(ex['label'])
        if len(sample['text']) == required:
            break
    # Explicit raise instead of `assert` so the check survives `python -O`;
    # the exception type and message are unchanged.
    if len(sample['text']) != required:
        raise AssertionError("The number of non-member is less than 100")
    sample_data = Dataset.from_dict(sample)
    # These counts only cover the rows scanned before the loop stopped.
    print("member nums: ", seen_labels.count(1))
    print("non-member nums: ", seen_labels.count(0))
    print("all nums: ", len(seen_labels))
    return sample_data
  
def get_dataset_json(path):
    """Read a JSON file of records and return it as a three-column Dataset.

    The file at ``path`` is parsed with ``json.load``; every record must
    provide the keys 'text', 'label' and 'length', which become the columns
    of the returned ``Dataset``.
    """
    with open(path, 'r') as handle:
        records = json.load(handle)

    columns = {'text': [], 'label': [], 'length': []}
    for record in records:
        for key in ('text', 'label', 'length'):
            columns[key].append(record[key])
    return Dataset.from_dict(columns)

def prompts_complete(examples):
    """Format each text as a completion task.

    Every entry of ``examples["text"]`` is split on whitespace; the first
    half of the words (floor division) fills the "Input" slot of the
    template and the remaining words fill the "Response" slot.

    Returns:
        A list with one prompt string per input text.
    """
    prompts = []
    for raw in examples["text"]:
        tokens = raw.split()
        cut = len(tokens) // 2
        input_text = " ".join(tokens[:cut])
        Response = " ".join(tokens[cut:])
        prompts.append(f'''Below is an  an incomplete input, Write a response that appropriately completes the input.
        
        ### Input:
        {input_text}
        
        ### Response:
        {Response}
        ''')
    return prompts

def ppl_prompts(examples):
    """Format each (text, label) pair as a Yes/No membership question.

    The answer slot is "Yes" when the label equals 1 and "No" for any other
    label value.

    Returns:
        A list with one prompt string per entry of ``examples["text"]``.
    """
    prompts = []
    for idx, input_text in enumerate(examples["text"]):
        answer = "Yes" if examples["label"][idx] == 1 else "No"
        text = f'''
        Below is an input may be from pre-training corpus. if the input comes from the pre-training corpora, the answer is "Yes", otherwise, it is "No". Please provide an answer.
        
        input: {input_text}
        
        answer: {answer}
        '''
        prompts.append(text)
    return prompts

def expert_prompts(examples):
    """Format each (text, label) pair under the shared expert instruction.

    Each prompt starts with ``expert_prompt_text`` followed by the sample in
    an "Input" section and "Yes" (label equals 1) or "No" (any other label)
    in an "answer" section.

    Returns:
        A list with one prompt string per entry of ``examples["text"]``.
    """
    prompts = []
    for idx, input_text in enumerate(examples["text"]):
        answer = "Yes" if examples["label"][idx] == 1 else "No"
        text = f'''{expert_prompt_text}
        
        ### Input:
        {input_text}
        
        ### answer:
        {answer}
        '''
        prompts.append(text)
    return prompts

def general_prompts(examples):
    """Return the raw texts of all samples whose label is not 1.

    Samples labelled 1 are skipped, so only the remaining texts are used for
    fine-tuning; they are passed through without any prompt template. The
    input size and the number of kept texts are printed.

    Returns:
        The kept texts, in their original order.
    """
    kept = [
        sample
        for idx, sample in enumerate(examples["text"])
        if examples["label"][idx] != 1
    ]
    print("overall numbers: ", len(examples['text']))
    print("fine-tuning dataset nums: ", len(kept))
    return kept

def format_prompts(examples):
    """Format each (text, label) pair as an instruction-style Yes/No prompt.

    The sample goes into the "Input" section; the "answer" section holds
    "Yes" when the label equals 1 and "No" for any other label value.

    Returns:
        A list with one prompt string per entry of ``examples["text"]``.
    """
    prompts = []
    for idx, input_text in enumerate(examples["text"]):
        answer = "Yes" if examples["label"][idx] == 1 else "No"
        text = f'''Below is an input may be from pre-training corpus. if the input is seen in the pre-training step, the answer is "Yes", otherwise, it is "No". Please provide an answer. 
        
        ### Input:
        {input_text}
        
        ### answer:
        {answer}
        '''
        prompts.append(text)
    return prompts

# Maps a prompt-style name to the function that turns a batch of examples
# into a list of training strings. run_training() looks up args.prompt here
# and passes the result to SFTTrainer as `formatting_func`.
prompt_dict = {
    "complete": prompts_complete,
    "answer": format_prompts,
    "expert": expert_prompts,
    "perplexity": ppl_prompts,
    "general": general_prompts
}
  
def print_trainable_params(model):
    """
    Report how many of the model's parameters are trainable.

    Walks ``model.named_parameters()`` once, then prints the trainable
    element count, the total element count and the trainable percentage.
    """
    sizes = [
        (param.numel(), param.requires_grad)
        for _, param in model.named_parameters()
    ]
    all_params = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    percentage = 100 * trainable_params / all_params

    print(
        f"trainable params: {trainable_params} || all params: {all_params} || trainable%: {percentage}"
    )


def _build_peft_config(tuner):
    """Return the PEFT adapter config selected by ``tuner``.

    Only the requested config is instantiated, so constructing one adapter
    type can never fail a run that uses the other.

    Raises:
        KeyError: if ``tuner`` is neither "lora" nor "adalora".
    """
    builders = {
        "lora": lambda: LoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias='none',
            task_type='CAUSAL_LM',
        ),
        # AdaLoRA is built without task_type / target_modules, as before.
        "adalora": lambda: AdaLoraConfig(
            r=16,
            lora_alpha=32,
            lora_dropout=0.05,
            bias='none',
        ),
    }
    return builders[tuner]()


def run_training(args, train_data, tokenzier):
    """Fine-tune the base model on ``train_data`` with a PEFT adapter.

    Args:
        args: parsed command-line namespace; supplies the model id, the
            adapter type (``args.tuner``), the prompt style (``args.prompt``)
            and the optimisation hyper-parameters.
        train_data: dataset used for training; the same object is also passed
            as the evaluation dataset.
        tokenzier: tokenizer handed to ``SFTTrainer`` (the parameter name is
            kept as-is for backward compatibility).

    Returns:
        ``trainer.model`` after ``trainer.train()`` has finished.

    Raises:
        KeyError: if ``args.tuner`` or ``args.prompt`` is not a known key.
    """
    print('Loading the model')

    peft_config = _build_peft_config(args.tuner)

    train_data.start_iteration = 0

    print('Starting main loop')

    training_args = TrainingArguments(
        output_dir=args.ckpts_dir,
        dataloader_drop_last=True,
        max_steps=args.max_steps,
        num_train_epochs=args.num_train_epochs,
        evaluation_strategy="epoch",
        # No checkpoints are written during training.
        save_strategy="no",
        logging_steps=args.log_freq,
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        warmup_steps=args.num_warmup_steps,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        gradient_checkpointing=args.gradient_checkpointing,
        fp16=args.fp16,
        bf16=args.bf16,
        weight_decay=args.weight_decay,
        run_name=args.run_name,
        report_to="none",
        ddp_find_unused_parameters=False,
    )

    # The model is loaded in 8-bit and placed on this process's device index.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id, load_in_8bit=True, device_map={"": Accelerator().process_index}
    )

    # NOTE(review): gradient checkpointing is enabled on the model
    # unconditionally, independent of args.gradient_checkpointing.
    model.gradient_checkpointing_enable()
    model.config.use_cache = False

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenzier,
        max_seq_length=args.seq_length,
        args=training_args,
        eval_dataset=train_data,  # the training data doubles as the evaluation set
        train_dataset=train_data,
        peft_config=peft_config,
        formatting_func=prompt_dict[args.prompt],
        packing=False
    )

    print_trainable_params(trainer.model)

    print("Training...")
    trainer.train()
    return trainer.model


def main(args):
    """Run one fine-tuning experiment and save the resulting model.

    Creates the experiment directory
    ``<ckpts_dir>/<model name>/seed_<seed>/<dataset>-<timestamp>``, prepares
    the training data and tokenizer, trains via ``run_training`` and writes
    the trained model into that directory with ``save_pretrained``.
    """
    print(args.model_id)

    stamp = datetime.datetime.now().strftime("%m-%d-%H:%M:%S")
    model_name = args.model_id.split('/')[-1]
    exp_dir = os.path.join(
        args.ckpts_dir,
        model_name,
        f"seed_{args.seed}",
        f"{args.dataset}-{stamp}",
    )
    os.makedirs(exp_dir, exist_ok=True)

    train_dataset = get_dataset(args)

    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # The EOS token is reused as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model = run_training(args, train_dataset, tokenizer)

    print("Saving last checkpoint of the model")
    model.save_pretrained(exp_dir)


if __name__ == '__main__':
    # Command-line entry point: parse the options, seed the RNGs, silence
    # transformers logging below error level, then run the experiment.
    parser = argparse.ArgumentParser('Supervised Fintuning with PEFT')
    parser.add_argument("--seed", type=int, default=42)
    # The default must be one of `choices` and a name get_dataset() handles:
    # argparse does not check a default against `choices`, and the previous
    # default "wikiMIA" matched no branch, so a run without --dataset always
    # failed with "Unsupported dataset".
    parser.add_argument('--dataset', type=str, default="WikiMIA", choices=["StackMIA", "tofu", "wikiMIA_ontime", "gsm8k", "WikiMIA", "BookMIA", "BookTection", "arXivTection", "wikiMIA_2023"])

    parser.add_argument('--split', type=str, default='train')  # no validation split is made
    parser.add_argument("--ckpts_dir", type=str, default="/mnt/sharedata/ssd/users/zhanghx/save_data/checkpoint")
    parser.add_argument('--model_id', type=str, default="/mnt/sharedata/ssd/users/zhanghx/models/llama-7b")

    parser.add_argument('--seq_length', type=int, default=1024)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--max_steps', type=int, default=-1)
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1)

    parser.add_argument("--learning_rate", type=float, default=1e-3)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--num_warmup_steps", type=int, default=100)
    parser.add_argument("--weight_decay", type=float, default=0.0005)

    parser.add_argument('--fp16', action='store_true', default=False)
    parser.add_argument('--bf16', action='store_true', default=False)
    parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--log_freq", default=1, type=int)
    parser.add_argument("--save_freq", default=5000, type=int)
    parser.add_argument("--run_name", type=str, default="llama-7b-finetuned")

    # --prompt selects a key of prompt_dict; --tuner selects the PEFT adapter.
    parser.add_argument("--prompt", type=str, default="answer", choices=["complete", "answer", "expert", "perplexity", "general"])
    parser.add_argument("--tuner", type=str, default="lora", choices=["lora", "adalora"])

    # Fraction of each benchmark kept as the 'train' split in get_dataset().
    parser.add_argument("--train_size", type=float, default=0.3)
    args = parser.parse_args()

    set_seed(args.seed)
    logging.set_verbosity_error()

    main(args)
    